// @HEADER
//*********************************************************************//
//  SiYuan: A numerical PDE solver                                     //
//  Copyright (2022) YUAN Xi                                           //
//  This Software is released under the BSD 2-Clause license detailed  //
//  in the file "LICENSE" in the top-level SiYuan directory            //
//*********************************************************************//
// @HEADER

#ifndef __Response_Functional_impl_hpp__
#define __Response_Functional_impl_hpp__

#include "Teuchos_Comm.hpp"
#include "Teuchos_CommHelpers.hpp"
#include "Teuchos_dyn_cast.hpp"

#include "Sacado_Traits.hpp"
#include "Panzer_IntegrationRule.hpp"
#include "Panzer_PhysicsBlock.hpp"
#include "Panzer_Integrator_Scalar.hpp"
#include "Panzer_ResponseScatterEvaluator_Functional.hpp"
#include "Panzer_Response_Functional.hpp"
#include "Panzer_BlockedDOFManager.hpp"
#include "Panzer_GlobalIndexer_Utilities.hpp"

namespace SiYuan {

// Generic (non-derivative) case: reduce the locally accumulated functional
// value over all processes of the response communicator and publish the sum.
template <typename EvalT>
void Response_Functional<EvalT>::
scatterResponse()
{
  // Sacado::ScalarValue extracts the plain double from ScalarT.
  double localPart = Sacado::ScalarValue<ScalarT>::eval(value);
  double summed = 0.0;

  // Sum a single entry across every process.
  Teuchos::reduceAll(*this->getComm(), Teuchos::REDUCE_SUM, Thyra::Ordinal(1), &localPart, &summed);

  // Keep the reduced value on the object and write it into entry 0 of the
  // response's Thyra vector.
  value = summed;
  this->getThyraVector()[0] = summed;
}

// Jacobian case: move the ghosted dg/dx contributions (held in ghostedContainer_)
// into the unique (owned) derivative vector returned by getDerivative().
// Does nothing when no derivative vector has been set.
// Throws (via Teuchos::rcp_dynamic_cast with throw_on_fail) if the factory's
// container is not a panzer::ThyraObjContainer<double>, instead of dereferencing null.
template < >
void Response_Functional<panzer::Traits::Jacobian>::
scatterResponse()
{
  using Teuchos::rcp_dynamic_cast;

  Teuchos::RCP<Thyra::MultiVectorBase<double> > dgdx_unique = getDerivative();

  // if it's null, don't do anything
  if(dgdx_unique==Teuchos::null)
    return;

  // Temporary unique container whose X vector aliases column 0 of the
  // derivative storage, so the ghost-to-global step writes straight into it.
  uniqueContainer_ = linObjFactory_->buildLinearObjContainer();
  rcp_dynamic_cast<panzer::ThyraObjContainer<double> >(uniqueContainer_, true)->set_x_th(dgdx_unique->col(0));

  linObjFactory_->ghostToGlobalContainer(*ghostedContainer_,*uniqueContainer_,panzer::LinearObjContainer::X);

  // Release the temporary container once the scatter is complete.
  uniqueContainer_ = Teuchos::null;
}

#ifdef Panzer_BUILD_HESSIAN_SUPPORT
// Hessian case: same scatter as the Jacobian case. Move the ghosted derivative
// contributions (held in ghostedContainer_) into the unique vector returned by
// getDerivative(). Does nothing when no derivative vector has been set.
// Throws (via Teuchos::rcp_dynamic_cast with throw_on_fail) if the factory's
// container is not a panzer::ThyraObjContainer<double>, instead of dereferencing null.
template < >
void Response_Functional<panzer::Traits::Hessian>::
scatterResponse()
{
  using Teuchos::rcp_dynamic_cast;

  Teuchos::RCP<Thyra::MultiVectorBase<double> > dgdx_unique = getDerivative();

  // if it's null, don't do anything
  if(dgdx_unique==Teuchos::null)
    return;

  // Temporary unique container whose X vector aliases column 0 of the
  // derivative storage, so the ghost-to-global step writes straight into it.
  uniqueContainer_ = linObjFactory_->buildLinearObjContainer();
  rcp_dynamic_cast<panzer::ThyraObjContainer<double> >(uniqueContainer_, true)->set_x_th(dgdx_unique->col(0));

  linObjFactory_->ghostToGlobalContainer(*ghostedContainer_,*uniqueContainer_,panzer::LinearObjContainer::X);

  // Release the temporary container once the scatter is complete.
  uniqueContainer_ = Teuchos::null;
}
#endif

// Tangent case: reduce the functional value and each of its numDeriv()
// derivative components across all processes, then copy the reduced
// derivatives into the response's multivector storage.
template < >
void Response_Functional<panzer::Traits::Tangent>::
scatterResponse()
{
  const int n = value.size();
  const int num_deriv = this->numDeriv();
  TEUCHOS_ASSERT(n == 0 || n == num_deriv);

  // A process may hold a value with no derivative components (n == 0) while
  // other processes hold num_deriv of them (presumably because nothing was
  // accumulated locally). The derivative reduction below is a collective, so
  // every process must pass the same count. Pad such a value with num_deriv
  // zero derivatives, keeping its scalar part, before reducing.
  if (n == 0 && num_deriv > 0)
    value = ScalarT(num_deriv, value.val());

  ScalarT glbValue = ScalarT(num_deriv, 0.0);

  // do global summation -- it is possible to do the reduceAll() on the Fad's directly, but it is somewhat
  // complicated for DFad (due to temporaries that might get created).  Since this is just a sum, it is
  // easier to do the reduction for each value and derivative component.
  Teuchos::reduceAll(*this->getComm(), Teuchos::REDUCE_SUM, Thyra::Ordinal(1), &value.val(), &glbValue.val());
  if (num_deriv > 0)
    Teuchos::reduceAll(*this->getComm(), Teuchos::REDUCE_SUM, Thyra::Ordinal(num_deriv), value.dx(),  &glbValue.fastAccessDx(0));

  value = glbValue;

  // Entry 0 of the i-th array receives the i-th reduced derivative component.
  Thyra::ArrayRCP< Thyra::ArrayRCP<double> > deriv = this->getThyraMultiVector();
  for (int i=0; i<num_deriv; ++i)
    deriv[i][0] = glbValue.dx(i);
}

// Generic case: the solution vector space is ignored. Only the evaluation
// types that produce derivatives with respect to the solution (the Jacobian
// and, when enabled, Hessian specializations) need it.
template <typename EvalT>
void Response_Functional<EvalT>::
setSolnVectorSpace(const Teuchos::RCP<const Thyra::VectorSpaceBase<double> > & /* soln_vs */)
{
  // intentionally empty
}

// Jacobian evaluations produce derivatives with respect to the solution, so the
// derivative vector space is taken to be the solution vector space.
template < >
void Response_Functional<panzer::Traits::Jacobian>::
setSolnVectorSpace(const Teuchos::RCP<const Thyra::VectorSpaceBase<double> > & soln_vs)
{
  this->setDerivativeVectorSpace(soln_vs);
}

// Hessian evaluations (when enabled) also produce derivatives with respect to
// the solution, so the derivative vector space is the solution vector space.
#ifdef Panzer_BUILD_HESSIAN_SUPPORT
template < >
void Response_Functional<panzer::Traits::Hessian>::
setSolnVectorSpace(const Teuchos::RCP<const Thyra::VectorSpaceBase<double> > & soln_vs)
{
  this->setDerivativeVectorSpace(soln_vs);
}
#endif

// Register the evaluators that compute this functional response for one physics block:
//  1. optionally, a cell-integral evaluator (panzer::Integrator_Scalar) over `field`;
//  2. a ResponseScatterEvaluator_Functional that scatters `field` into the response.
//     When a linear object factory is present, it is given a FunctionalScatter
//     built from the factory's domain global indexer(s), for derivative scatters.
// @param responseName  name of the response; also the integrated field when no
//                      quadrature-point field was configured
// @param fm            field manager that receives the evaluators
// @param physicsBlock  supplies the cell data for the integration rule / scatter
template <typename EvalT>
void Response_Functional<EvalT>::
buildAndRegisterEvaluators(const std::string & responseName,
                           PHX::FieldManager<panzer::Traits> & fm,
                           const panzer::PhysicsBlock & physicsBlock,
                           const Teuchos::ParameterList & /* user_data */) const
{
   // Field that is integrated and scattered: the configured quadrature-point
   // field if one was given, otherwise the response name itself.
   const std::string field = (quadPointField_=="" ? responseName : quadPointField_);

   // build integration evaluator (integrate over element)
   if(requiresCellIntegral_) {
     // build integration rule to use in cell integral
     Teuchos::RCP<panzer::IntegrationRule> ir
        = Teuchos::rcp(new panzer::IntegrationRule(cubatureDegree_,physicsBlock.cellData()));

     // The integral overwrites the integrand field name (same name for both).
     Teuchos::ParameterList pl;
     pl.set("Integral Name", field);
     pl.set("Integrand Name",field);
     pl.set("IR",ir);

     Teuchos::RCP<PHX::Evaluator<panzer::Traits> > eval
         = Teuchos::rcp(new panzer::Integrator_Scalar<EvalT,panzer::Traits>(pl));

     this->template registerEvaluator<EvalT>(fm, eval);
   }

   // build scatter evaluator
   {
     Teuchos::RCP<panzer::FunctionalScatterBase> scatterObj;
     if(linObjFactory_!=Teuchos::null) {
        const auto domainIndexer = linObjFactory_->getDomainGlobalIndexer();
        TEUCHOS_ASSERT(domainIndexer!=Teuchos::null);

        auto ugi = Teuchos::rcp_dynamic_cast<const panzer::GlobalIndexer>(domainIndexer);
        auto bugi = Teuchos::rcp_dynamic_cast<const panzer::BlockedDOFManager>(domainIndexer);

        // Test the more-derived blocked manager first. If BlockedDOFManager is
        // itself a panzer::GlobalIndexer (NOTE(review): the case in recent
        // Panzer; confirm for the version in use), the generic cast also
        // succeeds for it, and checking it first would make this branch dead.
        if(bugi!=Teuchos::null) {
          scatterObj = Teuchos::rcp(new panzer::FunctionalScatter<panzer::LocalOrdinal,panzer::GlobalOrdinal>(nc2c_vector(bugi->getFieldDOFManagers())));
        }
        else if(ugi!=Teuchos::null) {
          std::vector<Teuchos::RCP<const panzer::GlobalIndexer> > ugis;
          ugis.push_back(ugi);

          scatterObj = Teuchos::rcp(new panzer::FunctionalScatter<panzer::LocalOrdinal,panzer::GlobalOrdinal>(ugis));
        }
        else {
          TEUCHOS_ASSERT(false); // no real global indexer to use
        }
     }

     // Evaluator that scatters `field` into the response named responseName
     // (scatterObj stays null when there is no linear object factory).
     Teuchos::RCP<PHX::Evaluator<panzer::Traits> > eval = Teuchos::rcp(
         new panzer::ResponseScatterEvaluator_Functional<EvalT,panzer::Traits>(field,responseName,physicsBlock.cellData(),scatterObj) );

     this->template registerEvaluator<EvalT>(fm, eval);

     // Require the scatter evaluator's first evaluated field, presumably so the
     // field manager schedules the scatter evaluator.
     fm.template requireField<EvalT>(*eval->evaluatedFields()[0]);
   }
}

// Report whether this response can be evaluated for evaluation type EvalT.
// Residual and Tangent are always supported. Jacobian (and Hessian, when built
// with Hessian support) need a linear object factory for their scatter.
// Every other type is unsupported.
template <typename EvalT>
bool Response_Functional<EvalT>::
typeSupported() const
{
  const auto evalTypeName = PHX::print<EvalT>();

  if(evalTypeName == PHX::print<panzer::Traits::Residual>())
    return true;
  if(evalTypeName == PHX::print<panzer::Traits::Tangent>())
    return true;

  // Derivative-with-respect-to-solution types scatter through linObjFactory_.
  bool needsLinearObjFactory = (evalTypeName == PHX::print<panzer::Traits::Jacobian>());
#ifdef Panzer_BUILD_HESSIAN_SUPPORT
  if(evalTypeName == PHX::print<panzer::Traits::Hessian>())
    needsLinearObjFactory = true;
#endif

  if(needsLinearObjFactory)
    return linObjFactory_ != Teuchos::null;

  return false;
}

}

#endif
